import torch
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
import torch.nn as nn

class BCAUSS(AbstractModel):
    """Treatment-effect model with a covariate-balancing regularizer (BCAUSS-style).

    A shared representation network feeds three heads: an outcome head for
    treated units, an outcome head for control units, and a propensity head.
    Within this class the propensity head receives no supervised loss; it is
    trained only through :meth:`regularization_loss`, which penalizes the
    squared difference between the inverse-propensity-weighted covariate means
    of the treated and control groups.
    """

    def __init__(self, config, dataset):
        """Build the representation network and the three prediction heads.

        Reads ``config['alpha']`` (weight of the balancing loss),
        ``config['bn']`` (enable batch normalization in the representation
        network), ``config['repre_layer_sizes']`` and
        ``config['pred_layer_sizes']``. ``self.loss_type`` is presumably set
        by ``AbstractModel`` from the config -- TODO confirm.

        Raises:
            NotImplementedError: if ``self.loss_type`` is neither 'MSE' nor 'CE'.
        """
        super().__init__(config, dataset)
        # NOTE(review): presumably the number of covariate columns -- verify.
        self.in_feature = self.dataset.size[1]
        self.alpha = self.config['alpha']
        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Optional input batch norm, then the hidden representation layers.
        input_norm = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        self.repre_layers = nn.Sequential(
            *(input_norm + get_linear_layers(self.in_feature, self.repre_layer_sizes, self.bn, nn.ReLU))
        )

        # Heads are built in a fixed order (treated, control, propensity) and
        # keep their original output-layer names, so parameter registration,
        # state-dict keys and random initialisation order are stable.
        self.pred_layers_treated = self._build_head('out1')
        self.pred_layers_control = self._build_head('out0')
        self.pred_layers_propensity = self._build_head('out2')

        if self.loss_type == 'MSE':
            self.loss_fct = nn.MSELoss(reduction='sum')
        elif self.loss_type == 'CE':
            # Heads emit logits; the sigmoid is applied inside the loss.
            self.loss_fct = nn.BCEWithLogitsLoss(reduction='sum')
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['MSE', 'CE']!")
        self.regu_loss = nn.MSELoss(reduction='sum')

    def _build_head(self, out_name):
        """Return an MLP head on the representation, ending in a 1-unit linear layer named ``out_name``."""
        head = nn.Sequential(*get_linear_layers(self.repre_layer_sizes[-1],
                                                self.pred_layer_sizes, False, nn.ReLU))
        head.add_module(out_name, nn.Linear(self.pred_layer_sizes[-1], 1))
        return head

    def forward(self, x):
        """Run all three heads on the shared representation of ``x``.

        Also caches the representation on ``self.repre``.

        Returns:
            tuple: ``(treat_output, control_output, propensity)`` where the two
            outcome outputs are raw (pre-sigmoid) head outputs and
            ``propensity`` is squashed into (0, 1) with a sigmoid.
        """
        self.repre = self.repre_layers(x)
        treat_output = self.pred_layers_treated(self.repre)
        control_output = self.pred_layers_control(self.repre)
        regu_output = torch.sigmoid(self.pred_layers_propensity(self.repre))
        return treat_output, control_output, regu_output

    def get_repre(self, x, device):
        """Return the learned representation of ``x`` without tracking gradients.

        Side effects: switches the whole module to eval mode (it is not
        switched back) and moves ``repre_layers`` -- only that submodule --
        to ``device``.
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def regularization_loss(self, x, t, regu_output):
        """Covariate-balancing loss between inverse-propensity-weighted groups.

        Treated units are weighted by ``t / p`` and control units by
        ``(1 - t) / (1 - p)``. The weighted covariate mean of each group is
        computed, and the summed squared difference between the two means is
        returned.

        If the batch has no treated or no control units, one of the weighted
        means is undefined (0 / 0). In that case a zero loss is returned
        instead of NaN.

        Args:
            x: covariates; assumes shape (batch, features) -- TODO confirm.
            t: treatment indicator, presumably 0/1 and broadcastable against
                ``x`` (e.g. shape (batch, 1)).
            regu_output: propensity scores in [0, 1] from :meth:`forward`.
        """
        # Shrink propensities from [0, 1] into roughly [0.001, 0.999] so that
        # neither 1 / p nor 1 / (1 - p) can blow up.
        propensity = (regu_output + 0.001) / 1.002
        treat_weight = t / propensity
        control_weight = (1 - t) / (1 - propensity)

        treat_norm = torch.sum(treat_weight)
        control_norm = torch.sum(control_weight)
        # Clamping only matters when a group is empty: with at least one unit
        # in a group its normalizer is >= ~1.001, so values are unchanged.
        eps = 1e-12
        treat_mean = torch.sum(treat_weight * x, dim=0) / treat_norm.clamp_min(eps)
        control_mean = torch.sum(control_weight * x, dim=0) / control_norm.clamp_min(eps)

        loss = self.regu_loss(treat_mean, control_mean)
        has_both_groups = (treat_norm > 0) & (control_norm > 0)
        return torch.where(has_both_groups, loss, torch.zeros_like(loss))

    def calculate_loss(self, x, t, y, w):
        """Return the factual outcome loss plus ``alpha`` times the balancing loss.

        Each unit's prediction is taken from the head that matches its observed
        treatment (``t == 1`` -> treated head, otherwise control head).
        ``w`` is accepted for interface compatibility but is not used here.
        """
        treat_output, control_output, regu_output = self.forward(x)
        pred = torch.where(t == 1, treat_output, control_output)
        factual_loss = self.loss_fct(pred, y)
        regu_loss = self.regularization_loss(x, t, regu_output)
        return factual_loss + self.alpha * regu_loss

    def predict(self, x, t):
        r"""Predict each unit's outcome under the given treatment assignment.

        Args:
            x: covariates fed to the representation network.
            t: treatment indicator; units with ``t == 1`` use the treated head,
                all others use the control head.

        Returns:
            torch.Tensor: raw head outputs when ``loss_type == 'MSE'``,
            otherwise sigmoid probabilities (the 'CE' heads emit logits).
        """
        treat_output, control_output, _ = self.forward(x)
        y = torch.where(t == 1, treat_output, control_output)
        if self.loss_type == 'MSE':
            return y
        return torch.sigmoid(y)